{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ECG Rhythm Classification\n",
    "## 1. Create Training Dataset\n",
    "### Sebastian D. Goodfellow, Ph.D."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Noteboook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import 3rd party libraries\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pylab as plt\n",
    "\n",
    "# Import local Libraries\n",
    "sys.path.insert(0, r'C:\\Users\\sebastian goodfellow\\Documents\\code\\deep_ecg')\n",
    "from deepecg.config.config import DATA_DIR\n",
    "from deepecg.training.data.ecg import ECG\n",
    "\n",
    "# Configure Notebook\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Load Training Labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save path\n",
    "path_save=os.path.join(DATA_DIR, 'training', 'disc', 'train')\n",
    "\n",
    "# Load training labels\n",
    "labels_train = json.load(open(os.path.join(path_save, 'labels', 'labels.json')))\n",
    "\n",
    "# Label lookup\n",
    "label_lookup = {'N': 0, 'A': 1, 'O': 2, '~': 3}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. MIT-BIH Atrial Fibrillation Database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_duration(waveform, length):\n",
    "    \"\"\"Set duration of ecg waveform.\"\"\"\n",
    "    if len(waveform) > length:\n",
    "        return waveform[0:length]\n",
    "    else:\n",
    "        return waveform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Inputs\n",
    "path_data=os.path.join(DATA_DIR, 'db2')\n",
    "duration=60.\n",
    "fs = 300\n",
    "length = int(duration * fs)\n",
    "\n",
    "# Load labels\n",
    "labels = pd.read_csv(os.path.join(path_data, 'labels', 'labels.csv'))\n",
    "\n",
    "# labels dictionary\n",
    "labels_dict_db2 = dict()\n",
    "\n",
    "# Loop through files\n",
    "for idx, row in labels.iterrows():\n",
    "    \n",
    "    # Load waveform\n",
    "    waveform = np.load(os.path.join(path_data, 'waveforms', row['file_name']))\n",
    "    \n",
    "    try:\n",
    "        # Process ECG waveform\n",
    "        ecg = ECG(file_name=row['file_name'], label=row['train_label'], waveform=waveform, filter_bands=[3, 45], fs=fs)\n",
    "\n",
    "        # Set waveform duration\n",
    "        waveform = set_duration(waveform=ecg.filtered, length=length)\n",
    "\n",
    "        if len(waveform) < length:\n",
    "            # Get remainder\n",
    "            remainder = length - len(waveform)\n",
    "\n",
    "            # Pad waveform\n",
    "            waveform = np.pad(waveform, (int(remainder / 2), remainder - int(remainder / 2)), 'constant', constant_values=0)\n",
    "\n",
    "        # Get label\n",
    "        labels_dict_db2[row['file_name'].split('.')[0]] = label_lookup[row['train_label']]\n",
    "\n",
    "        # Save waveform\n",
    "        np.save(os.path.join(path_save, 'waveforms', row['file_name']), waveform)\n",
    "        \n",
    "    except:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. MIT-BIH Normal Sinus Rhythm Database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs\n",
    "path_data=os.path.join(DATA_DIR, 'db3')\n",
    "duration=60.\n",
    "fs = 300\n",
    "length = int(duration * fs)\n",
    "\n",
    "# Load labels\n",
    "labels = pd.read_csv(os.path.join(path_data, 'labels', 'labels.csv'))\n",
    "\n",
    "# labels dictionary\n",
    "labels_dict_db3 = dict()\n",
    "\n",
    "# Loop through files\n",
    "for idx, row in labels.iterrows():\n",
    "    \n",
    "    # Load waveform\n",
    "    waveform = np.load(os.path.join(path_data, 'waveforms', row['file_name']))\n",
    "    \n",
    "    try:\n",
    "        # Process ECG waveform\n",
    "        ecg = ECG(file_name=row['file_name'], label=row['train_label'], waveform=waveform, filter_bands=[3, 45], fs=fs)\n",
    "\n",
    "        # Set waveform duration\n",
    "        waveform = set_duration(waveform=ecg.filtered, length=length)\n",
    "\n",
    "        if len(waveform) < length:\n",
    "            # Get remainder\n",
    "            remainder = length - len(waveform)\n",
    "\n",
    "            # Pad waveform\n",
    "            waveform = np.pad(waveform, (int(remainder / 2), remainder - int(remainder / 2)), 'constant', constant_values=0)\n",
    "\n",
    "        # Get label\n",
    "        labels_dict_db3[row['file_name'].split('.')[0]] = label_lookup[row['train_label']]\n",
    "\n",
    "        # Save waveform\n",
    "        np.save(os.path.join(path_save, 'waveforms', row['file_name']), waveform)\n",
    "        \n",
    "    except:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Merge Labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge labels\n",
    "labels_train.update(labels_dict_db2)\n",
    "labels_train.update(labels_dict_db3)\n",
    "\n",
    "# Save labels\n",
    "with open(os.path.join(path_save, 'labels', 'labels.json'), 'w') as file:\n",
    "    json.dump(labels_train, file, sort_keys=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
